{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:35:55.270587Z",
     "start_time": "2024-05-30T11:35:55.238520Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(9842:124849532983104,MainProcess):2024-05-30-20:13:52.545.26 [mindspore/run_check/_check_version.py:102] MindSpore version 2.2.14 and cuda version 11.2.146 does not match, CUDA version [['10.1', '11.1', '11.6']] are supported by MindSpore officially. Please refer to the installation guide for version matching information: https://www.mindspore.cn/install.\n",
      "/home/qd/anaconda3/envs/nlpcc/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.243 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from mindnlp.peft import (\n",
    "    get_peft_model, LNTuningConfig, PeftConfig, PeftModel, TaskType)\n",
    "from mindnlp.dataset import load_dataset\n",
    "from mindnlp.engine import Trainer, TrainingArguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:16:23.660075Z",
     "start_time": "2024-05-30T11:16:23.642160Z"
    }
   },
   "outputs": [],
   "source": [
    "# Text labels; indexed by the dataset's integer `Label` column.\n",
    "classes = [\"Unlabeled\", \"complaint\", \"no complaint\"]\n",
    "# Pretrained checkpoint used for both the tokenizer and the base model.\n",
    "model_name_or_path = \"bigscience/bloomz-560m\"\n",
    "\n",
    "# Configuration name passed to load_dataset for the ought/raft dataset.\n",
    "dataset_name = \"twitter_complaints\"\n",
    "# Fixed token length every example is padded/truncated to.\n",
    "max_length = 64\n",
    "# Optimisation settings forwarded to TrainingArguments.\n",
    "lr = 1e-4\n",
    "num_epochs = 25\n",
    "batch_size = 8  # used for both train and eval batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:16:23.679058Z",
     "start_time": "2024-05-30T11:16:23.660075Z"
    }
   },
   "outputs": [],
   "source": [
    "# PEFT configuration: LNTuning applied to a causal-LM task.\n",
    "peft_config = LNTuningConfig(task_type=TaskType.CAUSAL_LM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:16:30.746552Z",
     "start_time": "2024-05-30T11:16:23.679058Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/qd/anaconda3/envs/nlpcc/lib/python3.9/site-packages/datasets/load.py:1486: FutureWarning: The repository for ought/raft contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/ought/raft\n",
      "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
      "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Load the `dataset_name` configuration of ought/raft; the repository ships\n",
    "# custom loading code, so trust_remote_code=True is passed explicitly.\n",
    "dataset = load_dataset(\"ought/raft\", dataset_name, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:16:30.767334Z",
     "start_time": "2024-05-30T11:16:30.746552Z"
    }
   },
   "outputs": [],
   "source": [
    "# Keep only the tweet text and its label in both splits.\n",
    "for split in (\"train\", \"test\"):\n",
    "    dataset[split] = dataset[split].project([\"Tweet text\", \"Label\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:16:30.782853Z",
     "start_time": "2024-05-30T11:16:30.767334Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Tweet text', 'Label']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check which columns remain after the projection.\n",
    "dataset[\"train\"].get_col_names()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:16:32.559485Z",
     "start_time": "2024-05-30T11:16:30.782853Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 3)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tokenizer matching the base checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "# Display the end-of-sequence and padding token ids.\n",
    "tokenizer.eos_token_id, tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:16:32.580452Z",
     "start_time": "2024-05-30T11:16:32.562151Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Longest tokenized length among the class names.\n",
    "# NOTE(review): this value is only displayed; no later cell visible here reads it.\n",
    "target_max_length = max(tokenizer(classes, return_length=True)[\"length\"])\n",
    "target_max_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:16:32.596253Z",
     "start_time": "2024-05-30T11:16:32.582624Z"
    }
   },
   "outputs": [],
   "source": [
    "def num_label_2_text_label(num_label):\n",
    "    \"\"\"Return the entry of the module-level `classes` list at index `num_label`.\"\"\"\n",
    "    return classes[num_label]\n",
    "\n",
    "\n",
    "def concate_text_label(texts, labels):\n",
    "    \"\"\"Tokenize one tweet/label pair into fixed-length model inputs.\n",
    "\n",
    "    The label text gets a ``</s>`` suffix and is embedded in a\n",
    "    ``Tweet text : ... Label : ...`` prompt. The prompt is tokenized with the\n",
    "    suffixed label as ``text_target`` and padded/truncated to ``max_length``.\n",
    "\n",
    "    Args:\n",
    "        texts: Tweet text of one sample.\n",
    "        labels: Text label of the same sample.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(input_ids, attention_mask, labels)`` in which every pad-token\n",
    "        id inside ``labels`` has been replaced by -100.\n",
    "    \"\"\"\n",
    "    # NOTE(review): truncation applies to the whole prompt, so a tweet longer\n",
    "    # than max_length tokens may lose its trailing label in input_ids while\n",
    "    # `labels` still carries it -- confirm against the tokenizer's truncation side.\n",
    "    labels = f\"{str(labels)}</s>\"\n",
    "    encoded = tokenizer(\n",
    "        f\"Tweet text : {texts} Label : {labels}\",\n",
    "        text_target=labels,\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "        max_length=max_length,\n",
    "    )\n",
    "\n",
    "    # Swap padding positions in the target for -100 (presumably the index the\n",
    "    # loss ignores -- confirm).\n",
    "    pad_id = tokenizer.pad_token_id\n",
    "    masked_labels = []\n",
    "    for token_id in encoded[\"labels\"]:\n",
    "        masked_labels.append(-100 if token_id == pad_id else token_id)\n",
    "\n",
    "    return encoded[\"input_ids\"], encoded[\"attention_mask\"], masked_labels\n",
    "\n",
    "\n",
    "def datapipe(data):\n",
    "    \"\"\"Apply the two preprocessing maps to a dataset split.\n",
    "\n",
    "    First the ``Label`` column is mapped through ``num_label_2_text_label``\n",
    "    into a ``labels`` column; then ``Tweet text`` and ``labels`` are mapped\n",
    "    through ``concate_text_label`` into ``input_ids``, ``attention_mask`` and\n",
    "    ``labels``.\n",
    "\n",
    "    Args:\n",
    "        data: Dataset split exposing ``map``.\n",
    "\n",
    "    Returns:\n",
    "        The result of the second ``map`` call.\n",
    "    \"\"\"\n",
    "    with_text_labels = data.map(\n",
    "        num_label_2_text_label, input_columns=\"Label\", output_columns=\"labels\"\n",
    "    )\n",
    "    return with_text_labels.map(\n",
    "        concate_text_label,\n",
    "        input_columns=[\"Tweet text\", \"labels\"],\n",
    "        output_columns=[\"input_ids\", \"attention_mask\", \"labels\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the preprocessing pipeline over both splits.\n",
    "for split in (\"train\", \"test\"):\n",
    "    dataset[split] = datapipe(dataset[split])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['input_ids', 'attention_mask', 'labels']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the columns produced by datapipe.\n",
    "dataset[\"train\"].get_col_names()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:26:34.625478Z",
     "start_time": "2024-05-30T11:24:03.440010Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 100,352 || all params: 559,314,944 || trainable%: 0.017941948642087417\n"
     ]
    }
   ],
   "source": [
    "# Load the base causal LM, then wrap it according to peft_config.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "# Print trainable vs. total parameter counts.\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:37:58.231069Z",
     "start_time": "2024-05-30T11:37:58.215082Z"
    }
   },
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./output\",\n",
    "    # optimisation settings taken from the configuration cell\n",
    "    learning_rate=lr,\n",
    "    num_train_epochs=num_epochs,\n",
    "    # same batch size for training and evaluation\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    # evaluation and saving both use the \"epoch\" strategy\n",
    "    save_strategy=\"epoch\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:38:24.771458Z",
     "start_time": "2024-05-30T11:38:24.752472Z"
    }
   },
   "outputs": [],
   "source": [
    "# Train on the processed train split and evaluate on the processed test split.\n",
    "trainer = Trainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    eval_dataset=dataset[\"test\"],\n",
    "    train_dataset=dataset[\"train\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:39:58.807764Z",
     "start_time": "2024-05-30T11:38:31.756761Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                               \n",
      "  4%|▍         | 7/175 [00:17<00:50,  3.34it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.95366907119751, 'eval_runtime': 15.0169, 'eval_samples_per_second': 28.302, 'eval_steps_per_second': 3.596, 'epoch': 1.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      "  8%|▊         | 14/175 [00:36<02:19,  1.15it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.8654632568359375, 'eval_runtime': 14.9481, 'eval_samples_per_second': 28.432, 'eval_steps_per_second': 3.613, 'epoch': 2.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 12%|█▏        | 21/175 [00:54<02:20,  1.10it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.781277656555176, 'eval_runtime': 14.9663, 'eval_samples_per_second': 28.397, 'eval_steps_per_second': 3.608, 'epoch': 3.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 16%|█▌        | 28/175 [01:13<02:15,  1.09it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.702119827270508, 'eval_runtime': 14.9927, 'eval_samples_per_second': 28.347, 'eval_steps_per_second': 3.602, 'epoch': 4.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 20%|██        | 35/175 [01:32<02:08,  1.09it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.629334926605225, 'eval_runtime': 14.9047, 'eval_samples_per_second': 28.514, 'eval_steps_per_second': 3.623, 'epoch': 5.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 24%|██▍       | 42/175 [01:51<02:01,  1.09it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.564324855804443, 'eval_runtime': 15.0315, 'eval_samples_per_second': 28.274, 'eval_steps_per_second': 3.592, 'epoch': 6.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 28%|██▊       | 49/175 [02:09<01:56,  1.08it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.508313179016113, 'eval_runtime': 15.0053, 'eval_samples_per_second': 28.323, 'eval_steps_per_second': 3.599, 'epoch': 7.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 32%|███▏      | 56/175 [02:28<01:50,  1.07it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.4620137214660645, 'eval_runtime': 15.0438, 'eval_samples_per_second': 28.251, 'eval_steps_per_second': 3.59, 'epoch': 8.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 36%|███▌      | 63/175 [02:48<01:44,  1.08it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.425300121307373, 'eval_runtime': 15.2347, 'eval_samples_per_second': 27.897, 'eval_steps_per_second': 3.545, 'epoch': 9.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 40%|████      | 70/175 [03:07<01:38,  1.07it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.39694356918335, 'eval_runtime': 15.0218, 'eval_samples_per_second': 28.292, 'eval_steps_per_second': 3.595, 'epoch': 10.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 44%|████▍     | 77/175 [03:25<01:31,  1.07it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.3746113777160645, 'eval_runtime': 14.8727, 'eval_samples_per_second': 28.576, 'eval_steps_per_second': 3.631, 'epoch': 11.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 48%|████▊     | 84/175 [03:44<01:24,  1.08it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.355413436889648, 'eval_runtime': 14.8845, 'eval_samples_per_second': 28.553, 'eval_steps_per_second': 3.628, 'epoch': 12.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 52%|█████▏    | 91/175 [04:03<01:17,  1.08it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.33675479888916, 'eval_runtime': 15.0295, 'eval_samples_per_second': 28.278, 'eval_steps_per_second': 3.593, 'epoch': 13.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 56%|█████▌    | 98/175 [04:22<01:11,  1.08it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.317072868347168, 'eval_runtime': 15.1317, 'eval_samples_per_second': 28.087, 'eval_steps_per_second': 3.569, 'epoch': 14.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 60%|██████    | 105/175 [04:41<01:05,  1.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.296142578125, 'eval_runtime': 15.1123, 'eval_samples_per_second': 28.123, 'eval_steps_per_second': 3.573, 'epoch': 15.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 64%|██████▍   | 112/175 [05:00<00:58,  1.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.274670124053955, 'eval_runtime': 14.923, 'eval_samples_per_second': 28.479, 'eval_steps_per_second': 3.619, 'epoch': 16.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 68%|██████▊   | 119/175 [05:18<00:51,  1.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.2537055015563965, 'eval_runtime': 14.7471, 'eval_samples_per_second': 28.819, 'eval_steps_per_second': 3.662, 'epoch': 17.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 72%|███████▏  | 126/175 [05:38<00:47,  1.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.234174728393555, 'eval_runtime': 15.0733, 'eval_samples_per_second': 28.195, 'eval_steps_per_second': 3.582, 'epoch': 18.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 76%|███████▌  | 133/175 [05:56<00:39,  1.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.216706275939941, 'eval_runtime': 14.9157, 'eval_samples_per_second': 28.493, 'eval_steps_per_second': 3.62, 'epoch': 19.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 80%|████████  | 140/175 [06:15<00:32,  1.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.201669692993164, 'eval_runtime': 14.6981, 'eval_samples_per_second': 28.915, 'eval_steps_per_second': 3.674, 'epoch': 20.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 84%|████████▍ | 147/175 [06:34<00:25,  1.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.189225673675537, 'eval_runtime': 15.2096, 'eval_samples_per_second': 27.943, 'eval_steps_per_second': 3.55, 'epoch': 21.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 88%|████████▊ | 154/175 [06:53<00:19,  1.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.179466724395752, 'eval_runtime': 14.782, 'eval_samples_per_second': 28.751, 'eval_steps_per_second': 3.653, 'epoch': 22.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 92%|█████████▏| 161/175 [07:12<00:12,  1.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.172454357147217, 'eval_runtime': 15.5205, 'eval_samples_per_second': 27.383, 'eval_steps_per_second': 3.479, 'epoch': 23.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      " 96%|█████████▌| 168/175 [07:31<00:06,  1.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.168280124664307, 'eval_runtime': 14.7836, 'eval_samples_per_second': 28.748, 'eval_steps_per_second': 3.653, 'epoch': 24.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                 \n",
      "100%|██████████| 175/175 [07:50<00:00,  1.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 4.167085647583008, 'eval_runtime': 15.1478, 'eval_samples_per_second': 28.057, 'eval_steps_per_second': 3.565, 'epoch': 25.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 175/175 [07:52<00:00,  2.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_runtime': 472.5518, 'train_samples_per_second': 2.963, 'train_steps_per_second': 0.37, 'train_loss': 6.5912158203125, 'epoch': 25.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=175, training_loss=6.5912158203125, metrics={'train_runtime': 472.5518, 'train_samples_per_second': 2.963, 'train_steps_per_second': 0.37, 'train_loss': 6.5912158203125, 'epoch': 25.0})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run training; evaluation metrics are logged once per epoch (evaluation_strategy=\"epoch\").\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:40:14.754285Z",
     "start_time": "2024-05-30T11:40:14.720220Z"
    }
   },
   "outputs": [],
   "source": [
    "# e.g. 'twitter_complaints_bigscience_bloomz-560m_LN_TUNING_CAUSAL_LM'\n",
    "peft_model_id = (\n",
    "    f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    ").replace(\"/\", \"_\")  # the \"/\" from the checkpoint name becomes \"_\"\n",
    "# Save the PEFT model under that identifier.\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-30T11:40:31.242284Z",
     "start_time": "2024-05-30T11:40:31.235289Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'twitter_complaints_bigscience_bloomz-560m_LN_TUNING_CAUSAL_LM'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the identifier passed to save_pretrained.\n",
    "peft_model_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the saved LNTuning config\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "# load the base LM named in that config\n",
    "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
    "# load the saved LNTuning weights for the base LM via PeftModel.from_pretrained\n",
    "# (no merge step is called here)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt in the training layout, with the label left for the model to generate.\n",
    "text = \"Tweet text : I am very angry with the service provided by the company Label :\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': Tensor(shape=[1, 16], dtype=Int64, value=\n",
       "[[227985,   5484,    915 ...  16333,  77658,    915]]), 'attention_mask': Tensor(shape=[1, 16], dtype=Int64, value=\n",
       "[[1, 1, 1 ... 1, 1, 1]])}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tokenize the prompt; with return_tensors='ms' the values come back as Tensor objects\n",
    "# (see the displayed result).\n",
    "output = tokenizer(text, return_tensors='ms')\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate at most three new tokens. Stop on the tokenizer's end-of-sequence id:\n",
    "# training targets end with \"</s>\" (presumably the tokenizer's EOS token), and the\n",
    "# tokenizer reports eos_token_id=2 while id 3 is the padding token.\n",
    "out = model.generate(\n",
    "    input_ids=output[\"input_ids\"],\n",
    "    attention_mask=output[\"attention_mask\"],\n",
    "    max_new_tokens=3,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Tweet text : I am very angry with the service provided by the company Label : it is not'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the generated ids (prompt plus new tokens) back to text.\n",
    "tokenizer.decode(out[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
